Conversation
- Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format. - Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels. - Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively. These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations. Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
These enhancements tests the changes introduced for unswizzling Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format. - Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes. - Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations. Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization. - Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module. Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format. - Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality. - Updated the launch function to handle multiple tensor unswizzling operations efficiently. These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module. Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds unswizzle support for MXFP8 / NVFP4 scaling factors, implementing the inverse of the existing swizzle operation. It introduces new CUDA kernels (
Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant nvte_unswizzle_scaling_factors
participant unswizzle_scaling_factors
participant unswizzle_row_scaling_kernel
participant unswizzle_col_scaling_kernel
Caller->>nvte_unswizzle_scaling_factors: (swizzled_tensor, output_tensor, stream)
nvte_unswizzle_scaling_factors->>unswizzle_scaling_factors: dispatch
unswizzle_scaling_factors->>unswizzle_scaling_factors: validate scaling_mode, swizzled flag,\nm%128==0, k%4==0, output size == m*k
alt rowwise_unswizzle
unswizzle_scaling_factors->>unswizzle_row_scaling_kernel: <<<(DIVUP(tiles_k,n_tb), tiles_m)>>>\n(swizzled_ptr, compact_ptr, m, k)
note over unswizzle_row_scaling_kernel: 1) Linear load: swizzled global → SLM\n2) __syncthreads\n3) SLM tile → regs (swizzle index)\n4) regs_unshuffle<LType>\n5) Write regs → compact global
else columnwise_unswizzle
unswizzle_scaling_factors->>unswizzle_col_scaling_kernel: <<<(DIVUP(tiles_k,TB_DIM), DIVUP(tiles_m,vls))>>>\n(swizzled_ptr, compact_ptr, m, k)
note over unswizzle_col_scaling_kernel: 1) Linear load: swizzled global → SLM\n2) __syncthreads\n3) SLM (swizzle index) → regs\n4) regs_unshuffle_with_bit_shifts\n5) Write regs → compact global (M-major)
end
unswizzle_scaling_factors-->>Caller: compact scale_inv written to output
Last reviewed commit: 4410e9d |
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
85ea04b to
17dbb33
Compare
for more information, see https://pre-commit.ci
…ather than casting Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
for more information, see https://pre-commit.ci
|
@int-smart Please address the comments from Greptile and ideally also add the test case with the input not already padded to 128,128. |
|
@ptrendx Will look into these |
|
@ptrendx From what I am understanding then, there is no relevance of padding to the unswizzle kernel. Since the padding is already done during the swizzling operation I can just mirror it back to compact layout with the zero pads correctly in the compact layout and that should do. Is that assumption correct. Initially I was thinking of removing the padding from the scale_inv itself since this would be used for checkpointing |
- Updated unswizzling kernel implementations to remove original_M and original_K parameters, simplifying the function signatures. - Enhanced test suite to utilize new unswizzling data shapes, ensuring comprehensive coverage of aligned and padded cases. These changes improve the clarity and efficiency of the unswizzling process in the swizzle module. Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
for more information, see https://pre-commit.ci
| const bool has_rowwise_scale_inv = input->scale_inv.has_data(); | ||
| const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); | ||
| NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, | ||
| "Input tensor has both row-wise and column-wise scaling factors"); |
There was a problem hiding this comment.
Asymmetric handling of dual-scale tensors breaks round-trip correctness
unswizzle_scaling_factors explicitly rejects tensors that have both rowwise and columnwise scaling factors (line 1165–1166), but the counterpart swizzle_scaling_factors happily processes both scale types in a single call (it runs both the rowwise and columnwise swizzle paths sequentially).
This means calling the public round-trip pair —
nvte_swizzle_scaling_factors(input, swizzled, stream); // succeeds: handles both scales
nvte_unswizzle_scaling_factors(swizzled, output, stream); // FAILS: "Input tensor has both..."— will raise a runtime error for any MXFP8 tensor that carries both rowwise and columnwise scale factors (a common configuration in dual-path training).
The same asymmetry is present in the multi-tensor variant (multi_tensor_unswizzle_scaling_factors, line 1391–1392).
The fix is either:
- Support both scale types in the unswizzle path (mirror
swizzle_scaling_factors), or - Document the restriction in the header API comment so callers know to split the tensor or call two separate unswizzle invocations.
As-is, a user who relies on swizzle ↔ unswizzle being a perfect inverse pair for the general case will encounter a silent API contract violation.
Description
This PR adds unswizzle support for scaling factors and extends the swizzle module so scaling tensors can be converted from GEMM-swizzled layout back to compact layout, including multi-tensor paths. It also adds round-trip and standalone tests to validate unswizzle correctness.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
transformer_engine/common/swizzle/swizzle.cuand declarations intransformer_engine/common/include/transformer_engine/swizzle.htests/cpp/operator/test_swizzle.cu, including standalone unswizzle and swizzle→unswizzle round-trip coverageChecklist: